from nets.CreativePoints import *


def autopad(k, p=None, d=1):
	"""Return the padding that keeps a stride-1 convolution's output the same size as its input.

	Args:
		k: kernel size, either an int or a sequence of ints (one per spatial dim).
		p: explicit padding; when given it is returned unchanged.
		d: dilation; when > 1 the kernel is first expanded to its effective size
			``d * (k - 1) + 1``.

	Returns:
		``p`` if provided, otherwise ``k // 2`` (int kernel) or a list of
		per-dimension ``k_i // 2`` (sequence kernel).
	"""
	if d > 1:
		# Effective extent of a dilated kernel.
		if isinstance(k, int):
			k = d * (k - 1) + 1
		else:
			k = [d * (size - 1) + 1 for size in k]
	if p is not None:
		return p
	if isinstance(k, int):
		return k // 2
	return [size // 2 for size in k]


class SiLU(nn.Module):
	"""Sigmoid-weighted linear unit activation: ``x * sigmoid(x)``."""

	@staticmethod
	def forward(x):
		gate = torch.sigmoid(x)
		return x * gate


class Conv(nn.Module):
	"""Conv2d (no bias) -> BatchNorm2d -> activation.

	Args:
		c1: input channels.
		c2: output channels.
		k: kernel size.
		s: stride.
		p: padding; ``None`` lets ``autopad`` compute it from ``k`` and ``d``.
		g: groups.
		d: dilation.
		act: ``True`` -> the shared class-level SiLU instance; an ``nn.Module`` ->
			used as-is; anything else -> ``nn.Identity``.
	"""

	# One SiLU instance shared by every Conv that uses the default activation.
	default_act = SiLU()

	def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
		super().__init__()
		padding = autopad(k, p, d)
		self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, dilation=d, bias=False)
		self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
		if act is True:
			self.act = self.default_act
		elif isinstance(act, nn.Module):
			self.act = act
		else:
			self.act = nn.Identity()

	def forward(self, x):
		features = self.conv(x)
		return self.act(self.bn(features))

	def forward_fuse(self, x):
		# Skips BatchNorm; presumably used once BN has been folded into conv weights.
		features = self.conv(x)
		return self.act(features)


class P2CF(nn.Module):
	"""P2CF (Partial-Point-Channel-Fusion) block, used in place of the standard bottleneck.

	The input feeds two parallel branches:
	  * PConv -> ShuffleAttentionV1 in ``forward_spatial`` mode
	  * PWConv -> ShuffleAttentionV1 in ``forward_channel`` mode
	The branch outputs are concatenated along the channel axis (channel branch
	first), channel-shuffled with 2 groups, passed through BatchNorm + ReLU and
	projected to ``c2`` channels by a second PConv. A residual connection is added
	when ``shortcut`` is set and ``c1 == c2``.

	Args:
		c1: input channels.
		c2: output channels.
		shortcut: enable the residual connection (only effective when c1 == c2).
		g: groups forwarded to the PWConv branch.
		k: unused; kept so the signature matches the stock Bottleneck.
		e: hidden-channel ratio, hidden = int(c2 * e).

	NOTE(review): ``self.bn`` (and presumably ``self.pconv2``'s input) is sized for
	``hidden`` channels, while the tensor it receives is the concatenation of both
	branches. That only lines up if each attention branch emits ``hidden // 2``
	channels — confirm against ShuffleAttentionV1 / PConv in nets.CreativePoints.
	"""

	def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
		super().__init__()
		hidden = int(c2 * e)
		# Submodules are created in the original order and with the original
		# attribute names so state_dict keys and initialisation order are unchanged.
		self.channel_shuffle = ShuffleAttentionV1.channel_shuffle
		self.pconv1 = PConv(c1, hidden)
		self.spatial_attention = ShuffleAttentionV1(channel=hidden, forward_state='forward_spatial')
		self.pwconv = PWConv(c1, hidden, k=1, g=g)
		self.channel_attention = ShuffleAttentionV1(channel=hidden, forward_state='forward_channel')
		self.bn = nn.BatchNorm2d(hidden)
		self.relu = nn.ReLU(inplace=True)
		self.pconv2 = PConv(hidden, c2)
		self.add = shortcut and c1 == c2

	def forward(self, x):
		batch, _, height, width = x.size()

		spatial = self.spatial_attention(self.pconv1(x))
		channel = self.channel_attention(self.pwconv(x))

		# Join both branches on the channel axis, then shuffle channels (2 groups).
		fused = torch.cat([channel, spatial], dim=1).contiguous().view(batch, -1, height, width)
		fused = self.channel_shuffle(fused, 2)
		fused = self.relu(self.bn(fused))

		projected = self.pconv2(fused)
		return x + projected if self.add else projected


class C2f(nn.Module):
	"""CSP-style block with dense aggregation of P2CF outputs (YOLOv8 C2f layout).

	``cv1`` projects the input to ``2 * c`` channels, which are split into two
	halves. Each of the ``n`` P2CF modules is applied to the most recent feature
	map and its output is kept. All ``2 + n`` maps are concatenated and projected
	to ``c2`` channels by ``cv2``.

	Args:
		c1: input channels.
		c2: output channels.
		n: number of P2CF modules.
		shortcut: residual connection inside each P2CF.
		g: groups forwarded to each P2CF (used by its PWConv branch).
		e: expansion ratio; hidden width ``c = int(c2 * e)``.
	"""

	def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
		super().__init__()
		self.c = int(c2 * e)
		# PWConv replaces the stock YOLOv8 Conv for both projections.
		self.cv1 = PWConv(c1, 2 * self.c, 1, 1)
		self.cv2 = PWConv((2 + n) * self.c, c2, 1)
		# P2CF replaces the stock Bottleneck. ``g`` was previously accepted but
		# silently dropped; it is now forwarded (the default g=1 is unchanged).
		self.m = nn.ModuleList(P2CF(self.c, self.c, shortcut, g=g, e=1.0) for _ in range(n))

	def forward(self, x):
		# One projection, then split into two halves of ``c`` channels each.
		y = list(self.cv1(x).split((self.c, self.c), 1))
		# extend() consumes the generator lazily, so each P2CF sees the output of
		# the previous one via y[-1]; every intermediate map is kept (dense aggregation).
		y.extend(m(y[-1]) for m in self.m)
		return self.cv2(torch.cat(y, 1))


class SPPF(nn.Module):
	"""Spatial Pyramid Pooling - Fast.

	A single max-pool (stride 1, padding ``k // 2``) is applied three times in
	sequence. With k=5 the stacked receptive fields are 5, 9 and 13, matching the
	classic SPP kernels. The channel-halved input and the three pooled maps are
	concatenated and fused by ``cv2``.
	"""

	def __init__(self, c1, c2, k=5):
		super().__init__()
		hidden = c1 // 2
		self.cv1 = Conv(c1, hidden, 1, 1)
		self.cv2 = Conv(hidden * 4, c2, 1, 1)
		self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

	def forward(self, x):
		pyramid = [self.cv1(x)]
		for _ in range(3):
			pyramid.append(self.m(pyramid[-1]))
		return self.cv2(torch.cat(pyramid, 1))


class Backbone(nn.Module):
	"""YOLOv8-style backbone built from Conv downsampling stages and P2CF-based C2f blocks.

	Args:
		base_channels: width of the stem output; later stages use multiples of it.
		base_depth: number of P2CF modules per C2f (doubled in dark3/dark4).
		deep_mul: width multiplier for the last stage (dark5).
		phi: model scale key, one of "n", "s", "m", "l", "x"; selects the
			pretrained weights file.
		pretrained: if True, load backbone weights for ``phi`` (non-strict).

	``forward`` returns the outputs of dark3, dark4 and dark5 (feat1, feat2, feat3).
	Shape comments below assume a 3x640x640 input and base_channels=64.
	"""

	def __init__(self, base_channels, base_depth, deep_mul, phi, pretrained=False):
		super().__init__()
		# 3, 640, 640 => 64, 320, 320 (single stride-2 conv)
		self.stem = Conv(3, base_channels, 3, 2)

		# 64, 320, 320 => 128, 160, 160 => 128, 160, 160
		self.dark2 = nn.Sequential(
			Conv(base_channels, base_channels * 2, 3, 2),
			C2f(base_channels * 2, base_channels * 2, base_depth, True),
		)
		# 128, 160, 160 => 256, 80, 80 => 256, 80, 80
		self.dark3 = nn.Sequential(
			Conv(base_channels * 2, base_channels * 4, 3, 2),
			C2f(base_channels * 4, base_channels * 4, base_depth * 2, True),
		)
		# 256, 80, 80 => 512, 40, 40 => 512, 40, 40
		self.dark4 = nn.Sequential(
			Conv(base_channels * 4, base_channels * 8, 3, 2),
			C2f(base_channels * 8, base_channels * 8, base_depth * 2, True),
		)
		# 512, 40, 40 => 1024 * deep_mul, 20, 20 => 1024 * deep_mul, 20, 20
		self.dark5 = nn.Sequential(
			Conv(base_channels * 8, int(base_channels * 16 * deep_mul), 3, 2),
			C2f(int(base_channels * 16 * deep_mul), int(base_channels * 16 * deep_mul), base_depth, True),
			SPPF(int(base_channels * 16 * deep_mul), int(base_channels * 16 * deep_mul), k=5)
		)

		if pretrained:
			import os

			weights_path = {
				# Previously "n" pointed at the *l* weights, which cannot load into
				# an "n"-sized model (size mismatch errors even with strict=False).
				"n": 'SourceFile/pretrained/yolov8_n_backbone_weights.pth',
				"s": 'SourceFile/pretrained/yolov8_s_backbone_weights.pth',
				"m": 'SourceFile/pretrained/yolov8_m_backbone_weights.pth',
				"l": 'SourceFile/pretrained/yolov8_l_backbone_weights.pth',
				"x": 'SourceFile/pretrained/yolov8_x_backbone_weights.pth',
			}[phi]
			if os.path.isfile(weights_path):
				# These are local paths; load_state_dict_from_url can only download
				# URLs (or reuse a copy cached in model_dir), so read local files directly.
				checkpoint = torch.load(weights_path, map_location="cpu")
			else:
				# Fallback: real URLs, or a copy already cached in ./model_data.
				checkpoint = torch.hub.load_state_dict_from_url(url=weights_path, map_location="cpu", model_dir="./model_data")
			self.load_state_dict(checkpoint, strict=False)
			print("Load weights from " + weights_path.split('/')[-1])

	def forward(self, x):
		x = self.stem(x)
		x = self.dark2(x)
		# -----------------------------------------------#
		#   dark3 output is 256, 80, 80 — first returned feature map
		# -----------------------------------------------#
		x = self.dark3(x)
		feat1 = x
		# -----------------------------------------------#
		#   dark4 output is 512, 40, 40 — second returned feature map
		# -----------------------------------------------#
		x = self.dark4(x)
		feat2 = x
		# -----------------------------------------------#
		#   dark5 output is 1024 * deep_mul, 20, 20 — third returned feature map
		# -----------------------------------------------#
		x = self.dark5(x)
		feat3 = x
		return feat1, feat2, feat3
